跳到主要内容

SpringBoot 使用 AOP 做请求限流

除了使用现成的中间件,也可以自己使用 Redis 搭配 AOP 写一个请求限流的小功能,顺便用来复习 AOP 的使用

搭建环境

主要把这两个依赖引入进来

<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>

<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>

<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-aop</artifactId>
</dependency>

<!-- 引入一个第三方的工具包 -->
<dependency>
<groupId>cn.hutool</groupId>
<artifactId>hutool-all</artifactId>
<version>5.5.9</version>
</dependency>

注意:因为在 Spring Boot 2.x 后底层不再是使用 Jedis ,而是换成了 Lettuce 所以只需配置 Lettuce

spring:
redis:
host: 127.0.0.1
port: 6379
# 连接超时时间(ms)
timeout: 10000
# Redis默认情况下有16个分片,这里配置具体使用的分片,默认是0
database: 0
lettuce:
pool:
# 连接池最大连接数(使用负值表示没有限制) 默认 8
max-active: 100
# 连接池最大阻塞等待时间(使用负值表示没有限制) 默认 -1
max-wait: -1
# 连接池中的最大空闲连接 默认 8
max-idle: 8
# 连接池中的最小空闲连接 默认 0
min-idle: 0
# 注意:别忘了设置缓存用的是 redis
cache:
type: redis

配置 Template

@Configuration
@AutoConfigureAfter(RedisAutoConfiguration.class)
public class RedisRepositoryConfig {

@Bean
public StringRedisTemplate stringRedisTemplate(RedisConnectionFactory redisConnectionFactory) {
StringRedisTemplate template = new StringRedisTemplate();
template.setConnectionFactory(redisConnectionFactory);
return template;
}
}

取得请求对象的工具类

public final class ServletUtils {
private ServletUtils() {}

/**
* Gets current http servlet request.
*
* @return an optional http servlet request
*/
@NonNull
public static Optional<HttpServletRequest> getCurrentRequest() {
return Optional.ofNullable(RequestContextHolder.getRequestAttributes())
.filter(requestAttributes -> requestAttributes instanceof ServletRequestAttributes)
.map(requestAttributes -> ((ServletRequestAttributes) requestAttributes))
.map(ServletRequestAttributes::getRequest);
}

/**
* Gets request ip.
*
* @return ip address or null
*/
@Nullable
public static String getRequestIp() {
// 这个 ServletUtil 工具类是 hutool 提供的
return getCurrentRequest().map(ServletUtil::getClientIP).orElse(null);
}

/**
* Gets request header.
*
* @param header http header name
* @return http header of null
*/
@Nullable
public static String getHeaderIgnoreCase(String header) {
return getCurrentRequest().map(request -> ServletUtil.getHeaderIgnoreCase(request, header)).orElse(null);
}

/**
* Gets request URI
*/
@Nullable
public static String getRequestURI() {
return getCurrentRequest().map(HttpServletRequest::getRequestURI).orElse(null);
}
}

编写限流注解

@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface RequestRateLimit {
// 限流级别,这个注解下面定义了
RateLimitEnum limit();
// 时间单位
TimeUnit timeUnit() default TimeUnit.SECONDS;
}

限流级别枚举代码

public enum RateLimitEnum {
/**
* 例如第二个就是一秒内限制 5 次请求
*/
RRLimit_1_1(1, 1),
RRLimit_1_5(5, 1),
RRLimit_1_10(10,1 ),
RRLimit_1_60(60, 1),
;

private final Integer limitMax;
private final Integer expireTime;

RateLimitEnum(final Integer limitMax, final Integer expireTime) {
this.expireTime = expireTime;
this.limitMax = limitMax;
}

/**
* @return 时间限制内可以请求的次数
*/
public Integer getLimitMax() {
return this.limitMax;
}

/**
* @return 时间限制
*/
public Integer getExpireTime() {
return this.expireTime;
}
}

编写 Redis 限流逻辑

主要看这个 overRequestRateLimit 方法,原理就是对 Redis 里面的相应 key 自增,然后设置一个超时时间,在这个超时时间内,如果自增的 key 大于最大的流量限制,则返回 false

@Slf4j
@Component
@RequiredArgsConstructor
public class RedisService {

private final StringRedisTemplate stringRedisTemplate;

/**
* 是否被限制请求,这里利用 Redis 的自增功能来判断当前时间段,该请求访问了多少次
*
* @param key key
* @param expireTime 限制时间
* @param max 最大请求
* @param timeUnit 时间单位
* @param userAgent 标识
* @return 当前是否被限制请求
*/
public boolean overRequestRateLimit(@NonNull String key, final int expireTime, final int max,
@NonNull TimeUnit timeUnit, String userAgent) {
// 使用断言,这里 key 得非空
Assert.hasText(key, "redis key must not be blank");
// 当前请求访问了多少次
long count = increment(key, 1);
// 当前距离过期还有多长时间,过期了会返回 -1
long time = Optional.ofNullable(stringRedisTemplate.getExpire(key)).orElse(-1L);

/*
* 如果 count == 1 或者 time == -1 表示这个请求要重新开始计数了
*/
if (count == 1 || time == -1) {
expire(key, expireTime, timeUnit);
}

log.debug("UT api request limit rate:too many requests: key={}, redis count={}, max count={}, " +
"expire time= {} s, user-agent={} ", key, count, max, expireTime, userAgent);
// 判断是否大于这个阈值
return count > max;
}

/**
* Redis 这个 key 自增
*
* @param key key
* @param inc 自增增量
* @return key 当前的值
*/
private long increment(String key, int inc) {
return Optional.ofNullable(stringRedisTemplate.opsForValue().increment(key, inc)).orElseThrow(() -> new RedisException("自增失败"));
}

/**
* 设置过期时间
* @param key key
* @param expireTime 过期时间
* @param timeUnit 时间单位
*/
private void expire(String key, int expireTime, TimeUnit timeUnit) {
stringRedisTemplate.expire(key, expireTime, timeUnit);
}
}

对注解方法切面

对加入了 @RequestRateLimit 注解的方法进行切面,使之限流

@Component
@Aspect
@Slf4j
@RequiredArgsConstructor(onConstructor = @__(@Autowired))
public class RequestRateLimitAspect {

private final RedisService redisService;

// 切面设为全部加上 @RequestRateLimit 注解的方法
@Around("@annotation(com.alsritter.requestlimit.annotation.RequestRateLimit)")
public Object requestRateLimit(ProceedingJoinPoint point) throws Throwable {
// Gets request URI
String requestURI = ServletUtils.getRequestURI();
// Gets user-agent from header
String userAgent = ServletUtils.getHeaderIgnoreCase("user-agent");
// Gets IP
String requestIp = ServletUtils.getRequestIp();

// Gets method
final Method method = ((MethodSignature) point.getSignature()).getMethod();
// Gets annotation
RequestRateLimit requestRateLimit = method.getAnnotation(RequestRateLimit.class);
// Gets annotation params
RateLimitEnum limitEnum = requestRateLimit.limit();
TimeUnit timeUnit = requestRateLimit.timeUnit();

// Generates key
String key = String.format("api_request_limit_rate_%s_%s_%s", requestIp, method.getName(), requestURI);
// Checks
boolean over = redisService.overRequestRateLimit(key, limitEnum.getExpireTime(), limitEnum.getLimitMax(), timeUnit, userAgent);
if (over) {
throw new RuntimeException("请求过于频繁,请稍后重试。");
}

return point.proceed();
}
}

编写测试 Controller

这里给这个方法加入一个 @RequestRateLimit 注解,使之能被切面扫描到

@RestController
public class TestController {
private final DateTimeFormatter df = DateTimeFormatter.ofPattern("yyyy-MM-dd hh:mm:ss");

/**
* 令这个接口一秒只能访问 1 次
*/
@RequestRateLimit(limit = RateLimitEnum.RRLimit_1_1)
@GetMapping("/getDate")
public ResponseEntity<String> getDate() {
return ResponseEntity.ok(df.format(LocalDateTime.now()));
}
}

这样,当前请求数量过大(1 秒内大于 1)就会预警

不只是可以使用在 Controller 上面,还可以用在普通方法上面

补充:代理的调用方式

但是注意!!!这样直接在类内部调用自己的方法是不会触发切面的

@RestController
public class TestController {
private final DateTimeFormatter df = DateTimeFormatter.ofPattern("yyyy-MM-dd hh:mm:ss");

@GetMapping("/getDate")
public ResponseEntity<String> getDate() {
// 注意:这样直接在类内部调用 getDateStr 是不会触发切面的
String dateStr = getDateStr();
return ResponseEntity.ok(dateStr);
}

/**
* 令这个方法一秒只能访问 5 次
*/
@RequestRateLimit(limit = RateLimitEnum.RRLimit_1_1)
public String getDateStr() {
return df.format(LocalDateTime.now());
}
}

因为 AOP 会生成一个代理对象,具体查看 Spring AOP 自调 参考 完整剖析SpringAOP的自调用

这里使用 AopContext.currentProxy() 获取代理对象(其实直接使用 Bean 也是可以的,但是看起来太麻烦了)

首先在启动类上加上这个 @EnableAspectJAutoProxy 注解,允许代码中获取 proxy 类:

@SpringBootApplication
@EnableAspectJAutoProxy(exposeProxy = true)
public class RequestLimitApplication {
...

然后使用这个 AopContext.currentProxy() 取得当前代理对象

@RestController
public class TestController {
private final DateTimeFormatter df = DateTimeFormatter.ofPattern("yyyy-MM-dd hh:mm:ss");

@GetMapping("/getDate")
public ResponseEntity<String> getDate() {
// 应该改成如下的方式:
String dateStr = ((TestController)AopContext.currentProxy()).getDateStr();
return ResponseEntity.ok(dateStr);
}

/**
* 令这个方法一秒只能访问 5 次
*/
@RequestRateLimit(limit = RateLimitEnum.RRLimit_1_1)
public String getDateStr() {
return df.format(LocalDateTime.now());
}
}

再次使用就可以使用到 AOP 了

其实它内部和那个 RequestContextHolder.getRequestAttributes() 差不多,都是使用 ThreadLocal 对象